//+------------------------------------------------------------------+
//|                                     ONNX challenges REALTIME.mq5 |
//|                                     Copyright 2023, Omega Joctan |
//|                        https://www.mql5.com/en/users/omegajoctan |
//+------------------------------------------------------------------+
#property copyright "Copyright 2023, Omega Joctan"
#property link      "http://fxalgebra.com"
#property version   "1.00"

#include <Timeseries Deep Learning\onnx.mqh>
#include <Timeseries Deep Learning\tsdataprocessor.mqh>
#include <preprocessing.mqh>
#include <Dimensionality Reduction\PCA.mqh>

//--- PCA transform; created in OnInit only when use_pca is true and deleted in OnDeinit
CPCA *pca;

//--- ONNX models embedded at compile time. OnInit loads lstm_model_pca when
//--- use_pca is true, otherwise lstm_model_data.
#resource "\\Files\\model.eurusd.D1.onnx" as uchar lstm_model_data[]
#resource "\\Files\\model.eurusd.D1.PCA.onnx" as uchar lstm_model_pca[]

//--- Saved standardization parameters (mean / std) used when use_pca is false
#resource "\\Files\\EURUSD-SCALER\\mean.bin" as double standardization_scaler_mean[];
#resource "\\Files\\EURUSD-SCALER\\std.bin" as double standardization_scaler_std[];

//--- Saved standardization parameters used when use_pca is true; OnTick applies
//--- this scaler to the output of the PCA transform
#resource "\\Files\\EURUSD-PCA-SCALER\\mean.bin" as double standardization_pca_scaler_mean[];
#resource "\\Files\\EURUSD-PCA-SCALER\\std.bin" as double standardization_pca_scaler_std[];

//--- Saved PCA parameters, passed to the CPCA constructor in OnInit
#resource "\\Files\\EURUSD-PCA\\components-matrix.bin" as double pca_comp_matrix[];
#resource "\\Files\\EURUSD-PCA\\mean.bin" as double pca_mean[];

//--- Magic number set on every order; PosExists / PosClose only act on
//--- positions that carry it
#define MAGIC_NUMBER 14042024

#include <Trade\Trade.mqh>
#include <Trade\PositionInfo.mqh>
//--- Order execution helper (configured in OnInit) and position reader
CTrade m_trade;
CPositionInfo m_position;

//--- time_step_: number of past closed D1 bars fed to the model per prediction
//--- use_pca   : selects the PCA model, the PCA scaler and the PCA transform
input int time_step_ = 7;
input bool use_pca = true;

//time_step_ must match the time step used when the model was trained in the Python script

//--- ONNX model wrapper: initialized in OnInit, used for predictions in OnTick
CONNX onnx;
//--- Standardization scaler: created in OnInit, deleted in OnDeinit
StandardizationScaler *scaler;
//--- Turns the feature matrix into timeseries samples. ts_data_tensor holds the
//--- latest result; who frees it is not visible in this file — TODO confirm.
CTSDataProcessor ts_dataprocessor;
CTensors *ts_data_tensor;

//--- D1 bars copied by OnTick
MqlRates rates[];
//--- Class labels passed to onnx.predict_bin
vector classes_ = {0,1};
//--- Bar count last seen by isnewBar (0 = not recorded yet)
int prev_bars = 0;
//--- Latest tick, read before sending orders
MqlTick ticks;
//--- Symbol's minimum volume; set in OnInit and used as the volume of every order
double min_lot = 0;
//+------------------------------------------------------------------+
//| Expert initialization function                                   |
//+------------------------------------------------------------------+
int OnInit()
  {
//--- Validate inputs: OnTick builds a (time_step_ x 3) matrix and requests
//--- time_step_ bars, so a non-positive value can never produce a prediction.
   if (time_step_ < 1)
     {
       printf("Invalid time_step_=%d, it must be at least 1",time_step_);
       return INIT_PARAMETERS_INCORRECT;
     }

//--- Load the ONNX model and the preprocessing objects that belong to it.
//--- OnTick runs: (optional) PCA transform -> standardization -> model, so with
//--- use_pca the scaler loaded here is the one saved for the PCA pipeline.
   if (use_pca)
     {
       if (!onnx.Init(lstm_model_pca))
         {
           printf("Failed to initialize the PCA LSTM model");
           return INIT_FAILED;
         }

       scaler = new StandardizationScaler(standardization_pca_scaler_mean, standardization_pca_scaler_std); //loading the saved scaler
       pca = new CPCA(pca_mean, pca_comp_matrix); //loading the saved PCA transform

       if (CheckPointer(pca)==POINTER_INVALID)
         {
           printf("Failed to create the PCA object");
           return INIT_FAILED;
         }
     }
   else
     {
       if (!onnx.Init(lstm_model_data))
         {
           printf("Failed to initialize the LSTM model");
           return INIT_FAILED;
         }

       scaler = new StandardizationScaler(standardization_scaler_mean, standardization_scaler_std); //loading the saved scaler
     }

   if (CheckPointer(scaler)==POINTER_INVALID)
     {
       printf("Failed to create the standardization scaler");
       return INIT_FAILED;
     }

//--- Trade settings: tag orders with the EA magic number so PosExists/PosClose
//--- only see this EA's positions
   m_trade.SetExpertMagicNumber(MAGIC_NUMBER);
   m_trade.SetDeviationInPoints(100);
   m_trade.SetTypeFillingBySymbol(Symbol());
   m_trade.SetMarginMode();

//--- Every order uses the smallest volume the symbol allows; 0 means the value
//--- could not be read and every order would be rejected
   min_lot = SymbolInfoDouble(Symbol(), SYMBOL_VOLUME_MIN);
   if (min_lot <= 0)
     {
       printf("Failed to read the minimum volume for %s Err=%d",Symbol(),GetLastError());
       return INIT_FAILED;
     }

   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Expert deinitialization function                                 |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
//--- Free the objects OnInit created with `new`. Only dynamically allocated
//--- objects may be deleted. The globals are reset to NULL so that they never
//--- hold a dangling address after this call.
   if (CheckPointer(pca)==POINTER_DYNAMIC)
     {
       delete pca;
       pca = NULL;
     }

   if (CheckPointer(scaler)==POINTER_DYNAMIC)
     {
       delete scaler;
       scaler = NULL;
     }

//--- NOTE(review): ts_data_tensor is not released here because ownership of
//--- the object returned by extract_timeseries_data is not visible in this file.
  }
//+------------------------------------------------------------------+
//| Expert tick function                                             |
//+------------------------------------------------------------------+
void OnTick()
  {
//--- In live trading, run the pipeline once per new D1 bar. The features are
//--- always built from D1 bars, so checking the chart's own timeframe would
//--- re-run the model on unchanged data on any chart other than D1.
//--- In the strategy tester, every tick is processed.
   if (!MQLInfoInteger(MQL_TESTER))
      if (!isnewBar(PERIOD_D1))
        return;

//--- Copy the last time_step_ closed D1 bars (start position 1 skips the bar
//--- that is still forming). CopyRates returns -1 on error and may copy fewer
//--- bars than requested. Either case would make the loop below read past the
//--- end of rates[], so the full count is required.
   if (CopyRates(Symbol(), PERIOD_D1, 1, time_step_, rates)<time_step_)
     {
       printf("Failed to collect data Err=%d",GetLastError());
       return;
     }

//--- Build the (time_step_ x 3) feature matrix: open, high, low for each bar.
//--- NOTE(review): rates[] is not set as series, so row 0 is presumably the
//--- oldest bar — confirm this matches the row order used during training.
   matrix data(time_step_, 3);
   for (int i=0; i<time_step_; i++)
     {
       data[i][0] = rates[i].open;
       data[i][1] = rates[i].high;
       data[i][2] = rates[i].low;
     }

//--- Convert to timeseries samples. Per the original author, the tensor holds a
//--- single matrix for the latest bars, stored at index 0.
   ts_data_tensor = ts_dataprocessor.extract_timeseries_data(data, time_step_);
   if (CheckPointer(ts_data_tensor)==POINTER_INVALID)
     {
       printf("Failed to process the timeseries data");
       return;
     }
//--- NOTE(review): if extract_timeseries_data allocates a new CTensors on every
//--- call, the previous one is never freed here — confirm ownership.
   data = ts_data_tensor.Get(0);

//--- Same preprocessing order the models expect: optional PCA, then scaling
   if (use_pca)
      data = pca.transform(data);

   data = scaler.transform(data);

   int signal = onnx.predict_bin(data, classes_);

   Comment("LSTM trade signal: ",signal);

//--- Open trades based on Signals
   if (!SymbolInfoTick(Symbol(), ticks))
     {
       printf("Failed to get the latest tick Err=%d",GetLastError());
       return;
     }

//--- Trading rule (unchanged):
//---   - If no position of the signalled type exists, open one.
//---   - Otherwise close this EA's positions of both types.
//--- NOTE(review): a position of the opposite type is not closed before
//--- opening, so on hedging accounts a buy and a sell can be open together.
   if (signal==1)
     {
       if (!PosExists(POSITION_TYPE_BUY))
         {
           if (!m_trade.Buy(min_lot,Symbol(), ticks.ask))
             printf("Failed to open buy position Err=%s",m_trade.ResultRetcodeDescription());
         }
       else
        {
          PosClose(POSITION_TYPE_BUY);
          PosClose(POSITION_TYPE_SELL);
        }
     }
   else
     {
       if (!PosExists(POSITION_TYPE_SELL))
         {
           if (!m_trade.Sell(min_lot,Symbol(), ticks.bid))
             printf("Failed to open sell position Err=%s",m_trade.ResultRetcodeDescription());
         }
       else
        {
          PosClose(POSITION_TYPE_SELL);
          PosClose(POSITION_TYPE_BUY);
        }
     }
  }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
bool isnewBar(ENUM_TIMEFRAMES TF)
 {
//--- Reports whether the bar count of Symbol() on TF has changed since the last
//--- call, using the global prev_bars as memory. The very first call only
//--- records the count and reports no new bar.
//--- NOTE(review): a bar-count comparison can miss new bars if Bars() stops
//--- growing (e.g. at the terminal's history limit); comparing the open time of
//--- bar 0 would be more robust — confirm before relying on long runs.
   int bars_now = Bars(Symbol(), TF);

   if (prev_bars == 0)
      prev_bars = bars_now;   // first call: just remember the count

   if (bars_now == prev_bars)
      return false;

   prev_bars = bars_now;
   return true;
 }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
bool PosExists(ENUM_POSITION_TYPE type)
 {
//--- True when this EA (same symbol and MAGIC_NUMBER) has an open position of
//--- the given type. Positions are scanned from the last index to the first.
   int total = PositionsTotal();

   for (int idx = total - 1; idx >= 0; idx--)
     {
       if (!m_position.SelectByIndex(idx))
          continue;

       bool is_ours = (m_position.Symbol() == Symbol() && m_position.Magic() == MAGIC_NUMBER);
       if (is_ours && m_position.PositionType() == type)
          return true;
     }

   return false;
 }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
void PosClose(ENUM_POSITION_TYPE type)
 {
//--- Close every position of the given type opened by this EA (same symbol and
//--- MAGIC_NUMBER). The loop runs backwards so that closing a position does not
//--- shift the indices that are still to be visited. Failures are logged and
//--- the remaining positions are still processed.
   for (int i=PositionsTotal()-1; i>=0; i--)
     {
      if (!m_position.SelectByIndex(i))
         continue;

      if (m_position.Symbol()!=Symbol() || m_position.Magic()!=MAGIC_NUMBER || m_position.PositionType()!=type)
         continue;

      ulong ticket = m_position.Ticket();
      //--- tickets are ulong: %I64u is the matching MQL5 format specifier
      //--- (%d misprints 64-bit values)
      if (!m_trade.PositionClose(ticket))
         printf("Failed to close position %I64u Err=%s",ticket,m_trade.ResultRetcodeDescription());
     }
 }
//+------------------------------------------------------------------+
//|                                                                  |
//+------------------------------------------------------------------+
